//
// Copyright Contributors to the MaterialX Project
// SPDX-License-Identifier: Apache-2.0
//

#include <MaterialXTest/External/Catch/catch.hpp>
#include <MaterialXTest/MaterialXGenShader/GenShaderUtil.h>

#include <MaterialXCore/Document.h>

#include <MaterialXFormat/File.h>
#include <MaterialXFormat/Util.h>

#include <MaterialXGenShader/ShaderStage.h>

#include <MaterialXGenOsl/OslShaderGenerator.h>
#include <MaterialXGenOsl/OslSyntax.h>

#include <MaterialXRenderOsl/OslRenderer.h>

namespace mx = MaterialX;

TEST_CASE("GenReference: OSL Reference", "[genreference]")
{
    mx::FileSearchPath searchPath = mx::getDefaultDataSearchPath();
    mx::DocumentPtr datalib = mx::createDocument();
    loadLibraries({ "libraries" }, searchPath, datalib);

    // Create renderer if requested: compile testing is enabled only when the build
    // provides a path to the OSL compiler executable.
    bool runCompileTest = !std::string(MATERIALX_OSL_BINARY_OSLC).empty();
    mx::OslRendererPtr oslRenderer = nullptr;
    if (runCompileTest)
    {
        oslRenderer = mx::OslRenderer::create();
        oslRenderer->setOslCompilerExecutable(MATERIALX_OSL_BINARY_OSLC);
        mx::FileSearchPath oslIncludePaths;
        mx::FilePath oslStandardIncludePath = mx::FilePath(MATERIALX_OSL_INCLUDE_PATH);
        if (!oslStandardIncludePath.isEmpty())
        {
            oslIncludePaths.append(oslStandardIncludePath);
        }
        // Add in library include path for compile testing as the includes added by the shader
        // generator itself are not added with absolute paths (specifically "mx_funcs.h").
        oslIncludePaths.append(searchPath.find("libraries/stdlib/genosl/include"));
        oslRenderer->setOslIncludePath(oslIncludePaths);
    }

    // Create shader generator.
    mx::ShaderGeneratorPtr generator = mx::OslShaderGenerator::create();

    // Register types from the library.
    generator->registerTypeDefs(datalib);

    mx::GenContext context(generator);
    context.getOptions().addUpstreamDependencies = false;
    context.registerSourceCodeSearchPath(searchPath);
    context.getOptions().fileTextureVerticalFlip = true;

    // Create output directory.
    mx::FilePath outputPath = searchPath.find("reference/osl");
    outputPath.getParentPath().createDirectory();
    outputPath.createDirectory();

    // Create log file.
    const mx::FilePath logPath("genosl_reference_generate_test.txt");
    std::ofstream logFile;
    logFile.open(logPath);

    // Generate one reference shader per nodedef.
    bool failedGeneration = false;
    const std::string nodeDefPrefix = "ND_";
    for (const mx::NodeDefPtr& nodedef : datalib->getNodeDefs())
    {
        // Name the reference after the nodedef rather than its node string. Several
        // nodedefs (e.g. the per-type variants of one node) share a node string, and
        // naming by it would make their reference files overwrite each other.
        std::string referenceName = nodedef->getName();
        if (referenceName.size() > nodeDefPrefix.size() &&
            referenceName.compare(0, nodeDefPrefix.size(), nodeDefPrefix) == 0)
        {
            referenceName = referenceName.substr(nodeDefPrefix.size());
        }

        // Only nodedefs with an implementation for the OSL target can be generated.
        mx::InterfaceElementPtr implementation = nodedef->getImplementation(generator->getTarget());
        if (!implementation)
        {
            logFile << "Skip generating reference for unimplemented node '" << referenceName << "'" << std::endl;
            continue;
        }

        // A node instance depends only on its nodedef, so a single instance is
        // sufficient (this also covers nodedefs that declare no inputs). The child
        // name is made unique so it cannot clash with an existing library element.
        mx::NodePtr node = datalib->addNodeInstance(nodedef, datalib->createValidChildName(referenceName));
        REQUIRE(node);

        const std::string filename = referenceName + ".osl";
        try
        {
            mx::ShaderPtr shader = generator->generate(node->getName(), node, context);

            std::ofstream file;
            const std::string filepath = (outputPath / filename).asString();
            file.open(filepath);
            REQUIRE(file.is_open());
            file << shader->getSourceCode();
            file.close();

            if (oslRenderer)
            {
                oslRenderer->compileOSL(filepath);
            }
        }
        catch (mx::ExceptionRenderError& e)
        {
            // Compilation failures carry the compiler's error log; record all of it.
            logFile << "Error compiling OSL reference for '" << referenceName << "' : " << std::endl;
            logFile << e.what() << std::endl;
            for (const std::string& error : e.errorLog())
            {
                logFile << error << std::endl;
            }
            failedGeneration = true;
        }
        catch (mx::Exception& e)
        {
            logFile << "Error generating OSL reference for '" << referenceName << "' : " << std::endl;
            logFile << e.what() << std::endl;
            failedGeneration = true;
        }

        // Remove the temporary instance so the library document is left unchanged.
        datalib->removeChild(node->getName());
    }

    logFile.close();

    CHECK(failedGeneration == false);
}
